{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import LeNet5_infernece\n",
    "import os\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "#### 1. 定义神经网络相关的参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 100\n",
    "LEARNING_RATE_BASE = 0.01\n",
    "LEARNING_RATE_DECAY = 0.99\n",
    "REGULARIZATION_RATE = 0.0001\n",
    "TRAINING_STEPS = 6000\n",
    "MOVING_AVERAGE_DECAY = 0.99"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. 定义训练过程"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train(mnist):\n",
    "    # 定义输出为4维矩阵的placeholder\n",
    "    x = tf.placeholder(tf.float32, [\n",
    "            BATCH_SIZE,\n",
    "            LeNet5_infernece.IMAGE_SIZE,\n",
    "            LeNet5_infernece.IMAGE_SIZE,\n",
    "            LeNet5_infernece.NUM_CHANNELS],\n",
    "        name='x-input')\n",
    "    y_ = tf.placeholder(tf.float32, [None, LeNet5_infernece.OUTPUT_NODE], name='y-input')\n",
    "    \n",
    "    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)\n",
    "    y = LeNet5_infernece.inference(x,False,regularizer)\n",
    "    global_step = tf.Variable(0, trainable=False)\n",
    "\n",
    "    # 定义损失函数、学习率、滑动平均操作以及训练过程。\n",
    "    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)\n",
    "    variables_averages_op = variable_averages.apply(tf.trainable_variables())\n",
    "    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))\n",
    "    cross_entropy_mean = tf.reduce_mean(cross_entropy)\n",
    "    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))\n",
    "    learning_rate = tf.train.exponential_decay(\n",
    "        LEARNING_RATE_BASE,\n",
    "        global_step,\n",
    "        mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY,\n",
    "        staircase=True)\n",
    "\n",
    "    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)\n",
    "    with tf.control_dependencies([train_step, variables_averages_op]):\n",
    "        train_op = tf.no_op(name='train')\n",
    "        \n",
    "    # 初始化TensorFlow持久化类。\n",
    "    saver = tf.train.Saver()\n",
    "    with tf.Session() as sess:\n",
    "        tf.global_variables_initializer().run()\n",
    "        for i in range(TRAINING_STEPS):\n",
    "            xs, ys = mnist.train.next_batch(BATCH_SIZE)\n",
    "\n",
    "            reshaped_xs = np.reshape(xs, (\n",
    "                BATCH_SIZE,\n",
    "                LeNet5_infernece.IMAGE_SIZE,\n",
    "                LeNet5_infernece.IMAGE_SIZE,\n",
    "                LeNet5_infernece.NUM_CHANNELS))\n",
    "            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: reshaped_xs, y_: ys})\n",
    "\n",
    "            if i % 50 == 0:\n",
    "                print(\"After %d training step(s), loss on training batch is %g.\" % (step, loss_value))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3. 主程序入口"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ../../datasets/MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting ../../datasets/MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting ../../datasets/MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting ../../datasets/MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "After 1 training step(s), loss on training batch is 4.26947.\n",
      "After 101 training step(s), loss on training batch is 1.01421.\n",
      "After 201 training step(s), loss on training batch is 0.928992.\n",
      "After 301 training step(s), loss on training batch is 0.810302.\n"
     ]
    }
   ],
   "source": [
    "def main(argv=None):\n",
    "    mnist = input_data.read_data_sets(\"../../datasets/MNIST_data\", one_hot=True)\n",
    "    train(mnist)\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    main()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
